# Copyright 2023 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Serve locally-built docs files.

There are essentially four components here:

1. A simple HTTP server that serves docs out of the build directory.

2. JavaScript that tells a doc to refresh itself when it receives a message
   that its source file has been changed. This is injected into each served
   page by #1.

3. A WebSocket server that pushes refresh messages to pages served by #1,
   which receive them through the WebSocket client included in #2.

4. A very simple file watcher that looks for changes in the built docs files
   and pushes messages about changed files to #3.
"""

import asyncio
import http.server
import io
import json
import logging
from pathlib import Path
import socketserver
import threading
from tempfile import TemporaryFile
from typing import Callable, Optional

from watchdog.events import FileModifiedEvent, FileSystemEventHandler
from watchdog.observers import Observer
import websockets

_LOG = logging.getLogger('pw_docgen.docserver')


def _generate_script(path: str, host: str, port: str) -> bytes:
    """Generate the JavaScript to inject into served docs pages."""
    return f"""<script>
    var connection = null;
    var originFilePath = "{path}";

    function watchForReload() {{
        connection = new WebSocket("ws://{host}:{port}/");
        console.log("Connecting to WebSocket server...");

        connection.onopen = function () {{
            console.log("Connected to WebSocket server");
        }}

        connection.onerror = function () {{
            console.log("WebSocket connection disconnected or failed");
        }}

        connection.onmessage = function (message) {{
            if (message.data === originFilePath) {{
                window.location.reload(true);
            }}
        }}
    }}

    watchForReload();
</script>
</body>
""".encode(
        "utf-8"
    )


class OpenAndInjectScript:
    """A substitute for `open` that injects the refresh handler script.

    For HTML files, instead of returning a handle to the file you asked for, it
    returns a handle to a temporary file in which the last ``</body>`` tag has
    been replaced by the refresh script. That file will disappear as soon as it
    is `.close()`ed, but that has to be done manually; it will not close
    automatically when exiting scope.

    Every other file (stylesheets, scripts, images, ...) is opened unmodified,
    because rewriting ``</body>`` inside those could corrupt them.

    The instance stores the last path that was opened in `path`.
    """

    # Only files with these (case-insensitive) suffixes get the script.
    _HTML_SUFFIXES = ('.html', '.htm')

    def __init__(self, host: str, port: str):
        self.path: str = ""
        self._host = host
        self._port = port

    def __call__(self, path: str, mode: str) -> io.BufferedReader:
        """Open `path` for reading, injecting the refresh script into HTML.

        Args:
            path: File to open.
            mode: File mode; must be a binary mode.

        Returns:
            A readable binary file handle positioned at the start.

        Raises:
            ValueError: If `mode` is not a binary mode.
            OSError: If `path` cannot be read.
        """
        if 'b' not in mode:
            raise ValueError(
                "This should only be used to open files in binary mode."
            )

        if Path(path).suffix.lower() not in self._HTML_SUFFIXES:
            # Non-HTML assets are served byte-for-byte.
            handle = open(path, mode)  # pylint: disable=consider-using-with
            self.path = path
            return handle  # type: ignore

        original = Path(path).read_bytes()
        # Inject before the *last* </body> only, so a page gets exactly one
        # refresh script even if the text "</body>" appears earlier in it.
        before, body_close, after = original.rpartition(b"</body>")
        if body_close:
            script = _generate_script(path, self._host, self._port)
            content = before + script + after
        else:
            content = original

        tempfile = TemporaryFile('w+b')
        try:
            tempfile.write(content)
            # Let the caller read the file like it's just been opened.
            tempfile.seek(0)
        except BaseException:
            # Don't leak the temporary file if preparing it fails.
            tempfile.close()
            raise
        # Store the path that held the original file.
        self.path = path
        return tempfile  # type: ignore


def _docs_http_server(
    address: str, port: int, path: Path
) -> Callable[[], None]:
    """A simple file system-based HTTP server for built docs."""

    class DocsStaticRequestHandler(http.server.SimpleHTTPRequestHandler):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, directory=str(path), **kwargs)

        # Disable logs to stdout.
        def log_message(
            self, format: str, *args  # pylint: disable=redefined-builtin
        ) -> None:
            return

    def http_server_thread():
        with socketserver.TCPServer(
            (address, port), DocsStaticRequestHandler
        ) as httpd:
            httpd.serve_forever()

    return http_server_thread


class TaskFinishedException(Exception):
    """Control-flow signal: one concurrently awaited task finished normally.

    Raised to tell the waiting loop to carry on with its next iteration.
    """


class WebSocketConnectionClosedException(Exception):
    """Control-flow signal: the client's WebSocket connection has closed.

    Raised to tell the per-connection loop to stop serving that client.
    """


class DocsWebsocketRequestHandler:
    """WebSocket server that sends page refresh info to clients.

    Push messages with `push_message` to broadcast them to all connected
    clients. `push_message` may be called from any thread; the server itself
    runs on its own asyncio event loop, started by `run`.
    """

    def __init__(self, address: str = '127.0.0.1', port: int = 8765):
        self._address = address
        self._port = port
        # Currently open client connections. Only mutated on the server's
        # event loop.
        self._connections = set()  # type: ignore
        self._messages: asyncio.Queue = asyncio.Queue()
        # The loop the server runs on, published once `run` starts. Other
        # threads must hand messages to it via `call_soon_threadsafe`.
        self._loop: Optional[asyncio.AbstractEventLoop] = None

    async def _register_connection(self, websocket) -> None:
        """Track a client connection for as long as it stays open."""
        self._connections.add(websocket)
        _LOG.info("Client connection established: %s", websocket.id)

        try:
            # Clients never send anything; we only need to know when they go
            # away so they stop receiving broadcasts.
            await websocket.wait_closed()
        finally:
            self._connections.discard(websocket)
            _LOG.info("Client connection dropped: %s", websocket.id)

    async def _send_messages(self) -> None:
        """Broadcast queued messages to all connected clients, forever.

        Every page change is broadcast to every client. It is up to the client
        to determine whether the contents of a message means it should refresh.
        This is a pretty easy determination to make though -- the client knows
        its own source file's path, so it just needs to check if the path in the
        message matches it.
        """
        while True:
            message = await self._messages.get()
            websockets.broadcast(self._connections, message)  # type: ignore # pylint: disable=no-member
            _LOG.info("Sent to %d clients: %s", len(self._connections), message)

    async def _run(self) -> None:
        # Create the queue on the loop that consumes it, then publish that
        # loop so other threads can enqueue onto it safely.
        self._messages = asyncio.Queue()
        self._loop = asyncio.get_running_loop()

        async with websockets.serve(  # type: ignore # pylint: disable=no-member
            self._register_connection, self._address, self._port
        ):
            # A single broadcaster serves all clients, so no per-connection
            # tasks are spawned (or leaked).
            await self._send_messages()

    def push_message(self, message: str) -> None:
        """Push a message on to the message queue.

        Safe to call from threads other than the server's (e.g. the file
        watcher). Messages are dropped while no client is connected or before
        the server has started.
        """
        loop = self._loop
        if loop is None or loop.is_closed() or not self._connections:
            return
        # asyncio.Queue is not thread-safe: schedule the put on the server's
        # own loop, which also wakes that loop so the message goes out now.
        loop.call_soon_threadsafe(self._messages.put_nowait, message)
        _LOG.info("Pushed to message queue: %s", message)

    def run(self):
        """Run the WebSocket server; blocks for the server's lifetime."""
        asyncio.run(self._run())


class DocsFileChangeEventHandler(FileSystemEventHandler):
    """Handle watched built doc files events.

    The path of each modified file is pushed to the WebSocket server, which
    broadcasts it so open pages showing that file can reload themselves.
    """

    def __init__(self, ws_handler: DocsWebsocketRequestHandler) -> None:
        super().__init__()
        self._ws_handler = ws_handler

    def on_modified(self, event) -> None:
        if isinstance(event, FileModifiedEvent):
            # Push the path of the modified file to the WebSocket server's
            # message queue, relative to the working directory when possible.
            src_path = Path(event.src_path)
            try:
                path = src_path.relative_to(Path.cwd())
            except ValueError:
                # The path is already relative, or lies outside the working
                # directory. Push it as-is instead of letting the ValueError
                # escape into the file watcher.
                path = src_path
            self._ws_handler.push_message(str(path))

        return super().on_modified(event)


class DocsFileChangeObserver(Observer):  # pylint: disable=too-many-ancestors
    """File watcher preconfigured to monitor a built docs directory tree.

    Construction schedules `event_handler` on `path`, including all of its
    subdirectories; call `start()` to begin watching. Any extra positional or
    keyword arguments are forwarded to the base `Observer` constructor.
    """

    def __init__(
        self, path: str, event_handler: FileSystemEventHandler, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        # Watch the whole tree, not just the top-level directory.
        include_subdirectories = True
        self.schedule(event_handler, path, recursive=include_subdirectories)
        _LOG.info("Watching build docs files at: %s", path)


def serve_docs(
    build_dir: Path,
    docs_path: Path,
    address: str = '127.0.0.1',
    port: int = 8000,
    ws_port: int = 8765,
) -> None:
    """Start serving built docs with live reload, then return.

    This actually spawns three threads, one each for the HTTP server, the
    WebSockets server, and the file watcher.

    Args:
        build_dir: Build output directory.
        docs_path: Docs output location, joined onto `build_dir`; its `html`
            subdirectory is what gets served and watched.
        address: Address both servers bind to; pages also use it to reach the
            WebSocket server.
        port: HTTP server port.
        ws_port: WebSocket server port.
    """
    html_dir = build_dir.joinpath(docs_path.joinpath('html'))
    serve_http = _docs_http_server(address, port, html_dir)

    # `http.server.SimpleHTTPRequestHandler.send_head` reads the file from
    # disk, sends the headers, then hands the open file over for the response
    # body. The only place to rewrite the file in between is the `open` that
    # `http.server` looks up, so we (somewhat distastefully) replace it.
    setattr(http.server, 'open', OpenAndInjectScript(address, str(ws_port)))

    ws_server = DocsWebsocketRequestHandler(address, ws_port)
    watcher_handler = DocsFileChangeEventHandler(ws_server)

    for thread_target, thread_name in (
        (ws_server.run, 'pw_docserver_ws'),
        (serve_http, 'pw_docserver_http'),
    ):
        threading.Thread(target=thread_target, name=thread_name).start()

    DocsFileChangeObserver(str(html_dir), watcher_handler).start()

    _LOG.info('Serving docs at http://%s:%d', address, port)


async def ws_client(
    address: str = '127.0.0.1',
    port: int = 8765,
):
    """Connect to the docs WebSocket server and log what it sends.

    Handy for testing the server by hand; every received message is logged.
    Run it like this: `asyncio.run(ws_client())`
    """
    server_uri = f"ws://{address}:{port}"
    async with websockets.connect(server_uri) as connection:  # type: ignore # pylint: disable=no-member
        _LOG.info("Connection ID: %s", connection.id)
        async for received in connection:
            _LOG.info("Message received: %s", received)
